
############################################ ################################################
######################################## FIGURE 4 ##########################################
############################################ ################################################

# Load necessary libraries using pacman
pacman::p_load(grf, foreign, readstata13, stargazer, dplyr, ggplot2, tidytext)

# Project root: the relative "plots" output directory used by this script is
# resolved against this working directory.
setwd("~/Dropbox/EJPR_replication/")

# Make sure the output directory for figures exists.
if (!dir.exists("plots")) {
  dir.create("plots")
}

# Load the analysis data. Base readRDS() is used because read_rds() lives in
# readr, which is not among the packages loaded by p_load() above.
dat <- readRDS("~/Dropbox/EJPR_replication/causal_forests_gip_data.rds")

# Fix the RNG state so the train/test split (and everything downstream that
# draws random numbers) is reproducible.
set.seed(123)

# Define all covariates passed to the causal forest. The order of this vector
# is also the column order of the covariate matrix.
covariate_names <- c(
  "women", "age_under32", "age_37_47", "age_52_62", "age_67_87",
  "has_partner", "N_household_members", "daily_internet", "CDU", "SPD",
  "FDP", "Grune", "Linke", "Other_party",
  "secondary_school", "intermediary_school", "university_applied",
  "university_abitur", "other_education",
  "in_vocational_training", "completed_vocational_training",
  "completed_academic_degree", "other_vocational_degree",
  "full_time_employed", "part_time_employed", "non_employment",
  "education_related_status", "other_status", "retired"
)

# Lookup table mapping each covariate to its display label
# (labels are listed in the same order as covariate_names).
covariate_labels <- data.frame(
  covariate = covariate_names,
  label = c(
    "Women", "Age - under 32", "Age  - 37-47", "Age - 52-62", "Age - 67-87",
    "Married/Has partner", "Number of household members", "Daily internet use",
    "CDU/CSU", "SPD",
    "FDP", "Die Grünen", "Die Linke", "Other party",
    "Secondary school", "Intermediary school", "University - applied",
    "University - abitur", "Other education",
    "In vocational training", "Completed vocational training",
    "Completed academic degree", "Other vocational degree",
    "Full-time employed", "Part-time employed", "Non-employment",
    "Educational status", "Other employment status", "Retired"
  )
)

outcome_variable_name <- "Trust_in_Gov"
treatment_variable_name <- "Speech"

# Keep only outcome, treatment and covariates. all_of() errors immediately if
# a column is missing, instead of warning and failing later on.
all_variables_names <- c(
  outcome_variable_name,
  treatment_variable_name,
  covariate_names
)
df <- dat %>% dplyr::select(dplyr::all_of(all_variables_names))

# Drop rows containing missing values
df <- na.omit(df)

# Rename outcome and treatment to the short names used below
df <- df %>% dplyr::rename(Y = Trust_in_Gov, W = Speech)

# Convert all columns to numeric (via character, so factor columns yield
# their level text rather than their internal integer codes)
df <- data.frame(lapply(df, function(x) as.numeric(as.character(x))))

# Train/test split: train_fraction of the rows go to the training set
train_fraction <- 0.80
n <- nrow(df)
train_idx <- sample.int(n, size = floor(n * train_fraction), replace = FALSE)
df_train <- df[train_idx, ]
df_test <- df[-train_idx, ]

# Fit the causal forest on the training split.
cf <- causal_forest(
  X = as.matrix(df_train[, covariate_names]),
  Y = df_train$Y,
  W = df_train$W,
  num.trees = 4041
)

### Predict point estimates and standard errors (training set, out-of-bag)
oob_pred <- predict(cf, estimate.variance = TRUE)
oob_tauhat_cf <- oob_pred$predictions
oob_tauhat_cf_se <- sqrt(oob_pred$variance.estimates)

# Variable importance, named by covariate and sorted from largest to smallest
var_imp <- c(variable_importance(cf))
names(var_imp) <- covariate_names
var_imp <- sort(var_imp, decreasing = TRUE)

num_tiles <- 4  # number of CATE quantile groups (4 = quartiles)

# Attach the out-of-bag CATE estimate and its quantile group to each row
df_train$cate <- oob_tauhat_cf
df_train$ntile <- factor(ntile(oob_tauhat_cf, n = num_tiles))

# Standard model estimates by quantile: keep only the ntile:W interaction
# terms. lm_robust()/tidy() are namespace-qualified because estimatr is not
# among the packages loaded at the top of the script.
estimated_sample_ate <-
  estimatr::lm_robust(Y ~ ntile + ntile:W, data = df_train) %>%
  estimatr::tidy() %>%
  dplyr::filter(stringr::str_detect(term, ":W"))

# AIPW estimates by quantile: one average_treatment_effect() call per group,
# stacked into one row per group
estimated_aipw_ate <-
  lapply(
    seq_len(num_tiles),
    function(w) average_treatment_effect(cf, subset = df_train$ntile == w)
  ) %>%
  bind_rows()

estimated_aipw_ate

# Stack both sets of estimates in a common layout. The AIPW rows reuse the
# term labels of the regression rows and get normal-approximation 95% bounds.
combined_estimates <-
  bind_rows(
    estimated_sample_ate %>%
      mutate(type = "lm_robust") %>%
      dplyr::select(-outcome, -df, -statistic, -p.value),
    estimated_aipw_ate %>%
      dplyr::rename(std.error = std.err) %>%
      mutate(
        type = "aipw",
        term = estimated_sample_ate$term
      ) %>%
      mutate(
        conf.low = estimate - 1.96 * std.error,
        conf.high = estimate + 1.96 * std.error
      )
  )

# Outputs, bundled and stored as `test`: this is the object fitted_vals()
# uses as its default `model` argument.
test <- list(
  cf = cf,
  df_train = df_train,
  X = as.matrix(df_train[, covariate_names]),
  oob_tauhat_cf = oob_tauhat_cf,
  var_imp = var_imp,
  ntile_estimates = combined_estimates
)

#' Predicted treatment effects along one covariate
#'
#' Varies `var_of_interest` over a grid while every other covariate is held at
#' its training-set median, and predicts the treatment effect (with standard
#' error) from the causal forest at each grid point.
#'
#' Reads the global `covariate_names` to determine the covariate set and the
#' column order handed to `predict()`.
#'
#' @param var_of_interest Name (single string) of the covariate to vary.
#' @param model List with elements `df_train` (training data frame) and `cf`
#'   (fitted causal forest). Defaults to the global object `test`.
#' @return A data frame with one row per grid value, sorted by the varied
#'   covariate, containing the covariate columns, `tauhat`, `se`, and a factor
#'   column named `var_of_interest` holding the grid values.
fitted_vals <- function(var_of_interest, model = test) {
  df_train <- model$df_train
  cf <- model$cf

  x_values <- df_train[[var_of_interest]]

  # Crude rule for determining continuity: more than 5 distinct values
  is_continuous <- length(unique(x_values)) > 5
  if (is_continuous) {
    # Continuous: evaluate at min, quartiles and max
    x_grid <- quantile(x_values, probs = seq(0, 1, length.out = 5))
  } else {
    # Discrete: evaluate at every observed value
    x_grid <- sort(unique(x_values))
  }

  df_grid <- setNames(data.frame(x_grid), var_of_interest)

  # Hold all remaining covariates at their median
  other_covariates <- covariate_names[!covariate_names %in% var_of_interest]
  df_median <- df_train %>%
    dplyr::select(all_of(other_covariates)) %>%
    summarise_all(median)

  # crossing() is namespace-qualified because tidyr is not among the packages
  # loaded at the top of the script.
  df_eval <- tidyr::crossing(df_median, df_grid)

  # Columns are put back into training order before prediction
  pred <- predict(
    cf,
    newdata = df_eval[, covariate_names],
    estimate.variance = TRUE
  )
  df_eval$tauhat <- pred$predictions
  df_eval$se <- sqrt(pred$variance.estimates)

  # Sort by the varied covariate itself (.data[[...]] looks the column up by
  # name; a bare `var_of_interest` would sort by the constant string).
  # Change to factor so the plotted values are evenly spaced (e.g. logicals).
  df_eval %>%
    arrange(.data[[var_of_interest]]) %>%
    mutate(var_of_interest = as.factor(as.numeric(.data[[var_of_interest]])))
}


# One row per covariate with its variable importance
hat_matters <- data.frame(covariate = names(var_imp), value = var_imp)

# Assign each covariate to a facet category. The party pattern uses the
# capitalised names as they appear in covariate_names (grepl is
# case-sensitive) and the label "Party" that the factor levels below expect.
# Anything unclassified becomes NA instead of being silently filed under a
# category.
hat_matters <- hat_matters %>%
  mutate(Category = case_when(
    grepl("^CDU|^SPD|^FDP|^Linke|^Grune|^Other_party", covariate) ~ "Party",
    grepl("^secondary_school|^intermediary_school|^university_applied|^university_abitur|^other_education", covariate) ~ "Education",
    grepl("in_vocational_training|^completed_vocational_training|^completed_academic_degree|^other_vocational_degree", covariate) ~ "Profession",
    grepl("full_time_employed|^part_time_employed|^non_employment|^education_related_status|^other_status|^retired", covariate) ~ "Employment",
    grepl("^women|^age|^has_partner|^N_household_members|^daily_internet", covariate) ~ "Demographics",
    TRUE ~ NA_character_
  ))

# Fix the top-to-bottom order of the facets
hat_matters$Category <- factor(
  hat_matters$Category,
  levels = c("Education", "Profession", "Employment", "Demographics", "Party")
)

# Identify the three covariates with the largest importance
top_three_values <- head(hat_matters %>% arrange(desc(value)), 3)$covariate

# Flag them so they can be coloured differently from the rest
hat_matters <- hat_matters %>%
  mutate(ColorCategory = ifelse(covariate %in% top_three_values, "Top 3", "Other"))

# Top three in red, everything else in gray
color_mapping <- c(
  "Top 3" = "red",
  "Other" = "gray60"
)

# Attach the display labels used on the y-axis
hat_matters_labels <- left_join(hat_matters, covariate_labels, by = "covariate")

# NOTE(review): scale_y_reordered() only rewrites labels built with
# tidytext::reorder_within(); `label` is a plain string here, so rows keep the
# default discrete-axis order. Confirm whether ordering by importance within
# each category was intended.
fig4 <- ggplot(hat_matters_labels, aes(value, label)) +
  geom_point(aes(color = ColorCategory), size = 3, show.legend = FALSE) +  # Ensure no legend
  scale_color_manual(values = color_mapping, guide = "none") +  # Remove color guide explicitly
  scale_x_continuous(name = "Variable Importance") +
  ggtitle("Heterogeneity in Political Trust") +
  ylab("") +
  facet_grid(Category ~ ., scales = "free_y", space = "free") +  # One panel (with title) per category
  scale_y_reordered() +
  # The complete theme must come first: adding theme_minimal() after theme()
  # would reset every customisation below back to the theme's defaults.
  theme_minimal(base_size = 12) +
  theme(
    strip.text.y = element_text(angle = 0, hjust = 0, size = 16, face = "bold"),  # Increased size for category titles
    strip.placement = "inside",  # Keep strip labels inside plot area as titles
    strip.background = element_blank(),  # Remove strip background for a cleaner look
    legend.position = "none",  # Ensure the legend is off
    axis.text.y = element_text(size = 10),  # Larger y-axis text for clarity
    plot.title = element_text(hjust = 0.5, size = 16, face = "bold")  # Center title and adjust size
  )

# Display the plot
fig4

# Write the figure to the plots directory
ggsave("plots/fig4.png", plot = fig4, width = 10, height = 8, dpi = 300)

